library(tidyverse, verbose = FALSE)
library(tidymodels, verbose = FALSE)
library(reticulate)
library(ggplot2)
library(plotly)
library(RColorBrewer)
library(bslib)
library(Metrics)
reticulate::use_virtualenv("r-tf")
Simone Brazzi
August 2, 2024
At prediction time, the model must return a vector containing a 1 or a 0 for each label present in the dataset (toxic, severe_toxic, obscene, threat, insult, identity_hate). This way, a harmless comment is classified by a vector of all 0s [0,0,0,0,0,0]. Conversely, a harmful comment has at least one 1 among the 6 labels.
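A minimal illustration of this target format (the label names come from the dataset; the example vectors are hypothetical):

```python
# Toy illustration of the target format: one binary flag per label.
labels = ["toxic", "severe_toxic", "obscene", "threat", "insult", "identity_hate"]

clean_comment = [0, 0, 0, 0, 0, 0]    # harmless: all zeros
harmful_comment = [1, 0, 1, 0, 1, 0]  # e.g. toxic, obscene and insult

# A comment is harmful if at least one label is 1.
print(any(clean_comment))    # False
print(any(harmful_comment))  # True
```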
Leveraging Quarto and RStudio, I will set up an R and Python environment.
Import the R libraries. These will be used both for rendering the document and for data analysis, since I prefer ggplot2 over matplotlib. I will also use colorblind-safe palettes.
import pandas as pd
import numpy as np
import matplotlib.pyplot as plt
import tensorflow as tf
import keras
import keras_nlp
from keras.backend import clear_session
from keras.models import Model, load_model
from keras.layers import TextVectorization, Input, Dense, Embedding, Dropout, GlobalAveragePooling1D, LSTM, Bidirectional, GlobalMaxPool1D, Flatten, Attention, LayerNormalization
from keras.metrics import Precision, Recall, AUC, SensitivityAtSpecificity, SpecificityAtSensitivity, F1Score
from sklearn.model_selection import train_test_split, KFold
from sklearn.metrics import multilabel_confusion_matrix, classification_report, ConfusionMatrixDisplay, precision_recall_curve, f1_score, recall_score, roc_auc_score
Create a Config class to store all the useful parameters for the model and the project.
I created a class with all the basic configuration of the model to improve readability.
class Config():
    def __init__(self):
        self.url = "https://s3.eu-west-3.amazonaws.com/profession.ai/datasets/Filter_Toxic_Comments_dataset.csv"
        self.max_tokens = 20000
        self.output_sequence_length = 911  # check the analysis done to establish this value
        self.embedding_dim = 128
        self.batch_size = 32
        self.epochs = 100
        self.temp_split = 0.3
        self.test_split = 0.5
        self.random_state = 42
        self.total_samples = 159571  # total train samples
        self.train_samples = 111699
        self.val_samples = 23936
        self.features = 'comment_text'
        self.labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate']
        self.new_labels = ['toxic', 'severe_toxic', 'obscene', 'threat', 'insult', 'identity_hate', "clean"]
        self.label_mapping = {label: i for i, label in enumerate(self.labels)}
        self.new_label_mapping = {label: i for i, label in enumerate(self.new_labels)}
        self.path = "/Users/simonebrazzi/R/blog/posts/toxic_comment_filter/history/f1score/"
        self.model = self.path + "model_f1.keras"
        self.checkpoint = self.path + "checkpoint.lstm_model_f1.keras"
        self.history = self.path + "lstm_model_f1.xlsx"
        self.metrics = [
            Precision(name='precision'),
            Recall(name='recall'),
            AUC(name='auc', multi_label=True, num_labels=len(self.labels)),
            F1Score(name="f1", average="macro")
        ]
    def get_early_stopping(self):
        early_stopping = keras.callbacks.EarlyStopping(
            monitor="val_f1",  # "val_recall"
            min_delta=0.2,
            patience=10,
            verbose=0,
            mode="max",
            restore_best_weights=True,
            start_from_epoch=3
        )
        return early_stopping

    def get_model_checkpoint(self, filepath):
        model_checkpoint = keras.callbacks.ModelCheckpoint(
            filepath=filepath,
            monitor="val_f1",  # "val_recall"
            verbose=0,
            save_best_only=True,
            save_weights_only=False,
            mode="max",
            save_freq="epoch"
        )
        return model_checkpoint
    def find_optimal_threshold_cv(self, ytrue, yproba, metric, thresholds=np.arange(.05, .35, .05), n_splits=7):
        # instantiate KFold
        kf = KFold(n_splits=n_splits, shuffle=True, random_state=self.random_state)
        threshold_scores = []
        for threshold in thresholds:
            cv_scores = []
            for train_index, val_index in kf.split(ytrue):
                ytrue_val = ytrue[val_index]
                yproba_val = yproba[val_index]
                ypred_val = (yproba_val >= threshold).astype(int)
                score = metric(ytrue_val, ypred_val, average="macro")
                cv_scores.append(score)
            mean_score = np.mean(cv_scores)
            threshold_scores.append((threshold, mean_score))
        # find the threshold with the highest mean score
        best_threshold, best_score = max(threshold_scores, key=lambda x: x[1])
        return best_threshold, best_score
config = Config()
The dataset is accessible using tf.keras.utils.get_file to get the file from the URL. N.B. For reproducibility purposes, I also downloaded the dataset, since there were times when the link was not available.
# A tibble: 5 × 8
comment_text toxic severe_toxic obscene threat insult identity_hate
<chr> <dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 "Explanation\nWhy the … 0 0 0 0 0 0
2 "D'aww! He matches thi… 0 0 0 0 0 0
3 "Hey man, I'm really n… 0 0 0 0 0 0
4 "\"\nMore\nI can't mak… 0 0 0 0 0 0
5 "You, sir, are my hero… 0 0 0 0 0 0
# ℹ 1 more variable: sum_injurious <dbl>
Let's create a clean variable for EDA purposes: I want to visually see how many observations are clean vs. the other labels.
First, a check on the dataset to find possible missing values and imbalances.
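As a sketch, the missing-value part of this check can be done with pandas (toy data here standing in for the real df loaded above):

```python
import pandas as pd

# Toy frame standing in for the real dataset.
df = pd.DataFrame({
    "comment_text": ["hello world", None, "you are awful"],
    "toxic": [0, 0, 1],
})

# Count missing values per column.
missing_per_column = df.isna().sum()
print(missing_per_column["comment_text"])  # 1
```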
library(reticulate)
df_r <- py$df
new_labels_r <- py$config$new_labels
df_r_grouped <- df_r %>%
  select(all_of(new_labels_r)) %>%
  pivot_longer(
    cols = all_of(new_labels_r),
    names_to = "label",
    values_to = "value"
  ) %>%
  group_by(label) %>%
  summarise(count = sum(value)) %>%
  mutate(freq = round(count / sum(count), 4))
df_r_grouped
# A tibble: 7 × 3
label count freq
<chr> <dbl> <dbl>
1 clean 143346 0.803
2 identity_hate 1405 0.0079
3 insult 7877 0.0441
4 obscene 8449 0.0473
5 severe_toxic 1595 0.0089
6 threat 478 0.0027
7 toxic 15294 0.0857
library(reticulate)
barchart <- df_r_grouped %>%
  ggplot(aes(x = reorder(label, count), y = count, fill = label)) +
  geom_col() +
  labs(
    x = "Labels",
    y = "Count"
  ) +
  # sort bars in descending order
  scale_x_discrete(limits = df_r_grouped$label[order(df_r_grouped$count, decreasing = TRUE)]) +
  scale_fill_brewer(type = "seq", palette = "RdYlBu")
ggplotly(barchart)
It is visible how imbalanced the dataset is. This means it could be useful to compute the class weights and use this argument during training.
It is clear that most of our texts are clean: 0.8033 of the observations are clean, while only 0.1967 are toxic comments.
To convert the text into a useful input for a NN, it is necessary to use a TextVectorization layer. See Section 4.
One of its arguments is output_sequence_length: to better define it, it is useful to analyze our text lengths. To simulate what the model will do, we are going to remove the punctuation and the newlines from the comments.
# A tibble: 1 × 6
Min. `1st Qu.` Median Mean `3rd Qu.` Max.
<dbl> <dbl> <dbl> <dbl> <dbl> <dbl>
1 4 91 196 378. 419 5000
library(reticulate)
boxplot <- df_r %>%
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>%
      str_remove_all("[[:punct:]]") %>%
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
  ) %>%
  # pull(text_length) %>%
  ggplot(aes(y = text_length)) +
  geom_boxplot() +
  theme_minimal()
ggplotly(boxplot)
library(reticulate)
df_ <- df_r %>%
  mutate(
    comment_text_clean = comment_text %>%
      tolower() %>%
      str_remove_all("[[:punct:]]") %>%
      str_replace_all("\n", " "),
    text_length = comment_text_clean %>% str_count()
  )
Q1 <- quantile(df_$text_length, 0.25)
Q3 <- quantile(df_$text_length, 0.75)
IQR <- Q3 - Q1
upper_fence <- as.integer(Q3 + 1.5 * IQR)
histogram <- df_ %>%
  ggplot(aes(x = text_length)) +
  geom_histogram(bins = 50) +
  geom_vline(aes(xintercept = upper_fence), color = "red", linetype = "dashed", linewidth = 1) +
  theme_minimal() +
  xlab("Text Length") +
  ylab("Frequency") +
  xlim(0, max(df_$text_length, upper_fence))
ggplotly(histogram)
Considering all the above analysis, I think a good starting value for output_sequence_length is 911, the upper fence of the boxplot (the dashed red vertical line in the last plot). Doing so, we remove the outliers, which are a small part of our dataset.
Now we can split the dataset in 3: train, test and validation sets. Since there is no sklearn function that splits into these 3 sets directly, we can do the following:
- split into a train set and a temporary set with a 0.3 split;
- split the temporary set into two equally sized test and validation sets.
x = df[config.features].values
y = df[config.labels].values
xtrain, xtemp, ytrain, ytemp = train_test_split(
    x,
    y,
    test_size=config.temp_split,  # .3
    random_state=config.random_state
)
xtest, xval, ytest, yval = train_test_split(
    xtemp,
    ytemp,
    test_size=config.test_split,  # .5
    random_state=config.random_state
)
The datasets are created using the tf.data.Dataset function. It creates a data input pipeline. The tf.data API makes it possible to handle large amounts of data, read from different data formats, and perform complex transformations. The tf.data.Dataset is an abstraction that represents a sequence of elements, in which each element consists of one or more components. Here each dataset is created using from_tensor_slices, which creates a tf.data.Dataset from a tuple (features, labels). .batch lets us work in batches to improve performance, while .prefetch overlaps the preprocessing and model execution of a training step: while the model is executing training step s, the input pipeline is reading the data for step s+1. Check the documentation for further information.
train_ds = (
    tf.data.Dataset
    .from_tensor_slices((xtrain, ytrain))
    .shuffle(xtrain.shape[0])
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
test_ds = (
    tf.data.Dataset
    .from_tensor_slices((xtest, ytest))
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
val_ds = (
    tf.data.Dataset
    .from_tensor_slices((xval, yval))
    .batch(config.batch_size)
    .prefetch(tf.data.experimental.AUTOTUNE)
)
train_ds cardinality: 3491
val_ds cardinality: 748
test_ds cardinality: 748
Check the first element of the dataset to be sure that the preprocessing is done correctly.
(array([b'and more details on his various positions in the oil industry',
b'"\n\n New Fake ""CryptoLocker"" circulating on the Internet \n\nA few days ago, in the institution where my wife works they catch it with CryptoLocker-like virus that in all external signs resembles normal CryptoLocker virus. I asked her to send me several ""encrypted"" files together with their unencrypted copies to analyze them and the first thing I noticed was the difference in their size. The ""encrypted"" files are smaller than the originals - something that will not happen when encoding with RSA, especially ""2048-RSA"" as in the message for ransom claimed. This ""CryptoLocker"" simply use some sort of data compression like in RAR format. \n\nDoes anyone have more information about this virus? "',
b'"\nHi Yemi.wikis, and Welcome to Wikipedia! \nWelcome to Wikipedia! I hope you enjoy the encyclopedia and want to stay. As a first step, you may wish to read the Introduction.\n\nIf you have any questions, feel free to ask me at my talk page \xe2\x80\x94 I\'m happy to help. Or, you can ask your question at the New contributors\' help page.\n\n \nHere are some more resources to help you as you explore and contribute to the world\'s largest encyclopedia...\n\n Finding your way around: \n\n Table of Contents\n\n Department directory\n\n Need help? \n\n Questions \xe2\x80\x94 a guide on where to ask questions.\n Cheatsheet \xe2\x80\x94 quick reference on Wikipedia\'s mark-up codes.\n\n Wikipedia\'s 5 pillars \xe2\x80\x94 an overview of Wikipedia\'s foundations\n The Simplified Ruleset \xe2\x80\x94 a summary of Wikipedia\'s most important rules.\n\n How you can help: \n\n Contributing to Wikipedia \xe2\x80\x94 a guide on how you can help.\n\n Community Portal \xe2\x80\x94 Wikipedia\'s hub of activity.\n\n Additional tips... \n\n Please sign your messages on talk pages with four tildes (~~~~). This will automatically insert your ""signature"" (your username and a date stamp). The button, on the tool bar above Wikipedia\'s text editing window, also does this. \n\n If you would like to play around with your new Wiki skills the Sandbox is for you. \n\n Good luck, and have fun. "',
b'"ChaJeCraft is a gmod,terraria and minecraft based server!\nFirst is was like ""OMG I MADED A SERVER YAY""!\n\nbut now its more like my hobby!\n\nWanna go play it???\n\nWanna meet Jensen,Lousisx and Chaline\n\nJust go to www.chajecraft.jimdo.com\n\nand www.chajecraft.actieforum.com\n\nThere you find ip,info and more!\n\nhere is a version history\n\n Chajecraft test version 1 \n\n Server test version released\n Hamachi\n ip:unknown\n Map called test\n Chajecraft test version 2 \n\n Server software ""MCadmin"" used\n More commands\n Chajecraft test version 3 \n \n\n Back to old server software\n Chajecraft test version 4 \n\n Server software ""canary"" used\n Plugins used\n Beginning to develope my own ranks\n Site started\n Motd\n more that i didnt remember\n Chajecraft test version 5 \n\n Added new ranks such as ""Co-owner,Owner-developer,Owner-builder,member,builder""\n Guests cant build anymore\n Chajecraft Version 1.0 \n\n non-hamachi server\n new ip:94.227.108.96\n more to come!\n Chajecraft Ranks and hosts \nyou may be looking for this page Hosts and Ranks"',
b'"\n Maybe the ""h"" was an accident?? And I haven\'t seen anything about the ""Urbanski"" either. I am about 90% sure he was born with the name Keith Lionel Urban, or he would have made reference to it somewhere. Or maybe I don\'t know him well enough and it is. Keith\'s Fan, Danielle"',
b'Help\n\nHow can I have the alt text changed on the wiki logo on the top left?',
b'The French and Indian War was not fought between the French and the Indians. It was fought between the French and the British. The Indians were on the side of the French. Well, most of the Indians were anyway. 2:19, 21 July 2004 (UTC)',
b"My memory, such as it is, tells me it was from a biography of Perkins, written when he was much in the limelight, that I was flipping through in a bookshop shortly after it was published. I have no memory of the author's name, but it seemed to be a substantial and well-written book. It may have been the Peter Read book, but I couldn't swear to it. I'll see if I can track this information down.",
b"Point Roberts \n\nHey there, I saw you did a bunch of work on this article, and was curious, if you live there, if you'd be able to get a picture of the Boundary Marker No. 1 at the Marine Drive crossing for the List of Registered Historic Places in Washington. Cheers!",
b', 12 May 2008 (UTC)\nBy the standard of other articles it is a minor infraction, and the image use is notably better than most televisual articles 13:45',
b"The discussion page of FF speaks for itself as it is one long personal attack and incivility from start to go. Hardly the point, when I initiated the discussion focusing on NPOV and sources, reverted (rather than edit war). Furthermore my last edit included much of the latest attacks (linked), which my edit adressed. If you were interested, you would have checked, rather than just focus on me while their attacks continue. We both know, my providing diffs won't matter when certain users are granted immunity. I followed all protocols and even tried to address admin, when I saw the discussion was sidetracking and going in circles. In vain. But I tried yet I am blocked. So we are through.",
b'Lex94, stop with the Lashley bull. Here is a list of the TCC titles.\n\nWWE Championship\nWorld Heavyweight Championship\nWWE Tag Team Championship\nWorld Tag Team Championship\nIC title.\n\nHere are the titles Lashley has won,\n\nECW title\nUS title.\n\nAre either of those on the list above? If the answer is no then the position is clear, refrain from filling talk pages with rubbish about how Lashley is a potential or unofficial champion. And sign you posts.',
b'":Then they would have came from the Central Rada, the Hetmanate, the Directorate, the Ukrainian People\'s Republic or the West Ukrainian People\'s Republic. Either way it\'s deprecated usage that originally contained ""the"". Regardless people usually refer to it in its present form (much like Beijing is now used instead of Peking) and proper English declares it without \'the\'. 8:31, 18 June 2007 (UTC)\n\n"',
b'Moslims are filthy thugs who do not belong on this planet.',
b"Typical - I'll pass this along.",
b'"\n\n AIV Reporting Bug \n\nYeah, I just noticed that too. I am using Vandal Proof and it just does that. I am also annoyed that Vandal Proof does not include an edit summary when reporting users on AIV. I just made a request on the request page for the addition of an edit summary. Now, I will report a bug with this weird % thing. My apologies, I hope it gets fixed soon. Cheers! e2221 "',
b'"\n\n""...who scored the biggest upsets in the history of boxing, perhaps the biggest upset in sports history..."" ... I think that needs a bit of an edit, not very objective.\n\n Wadey4"',
b'Arbitration: Konullu \n\nYou have been mentioned here. Wikipedia:Arbitration/Requests/Enforcement',
b"Sign your message with four tildes (~). It's not that hard.",
b"Aachen \n\nHello, I'm wondering what relics there are of the Frankish Empire. What architecture and museum-level pieces remain? I understand that Charlemagne's tomb lies in his cathedral, but what about the palaces? Where are those?",
b'What do you think Cuba is?',
b'"\n Thanks for the pointer. Both the pages (:Category:British politicians by nationality and Category:British writers by location) were tagged as speedy, but I can find no trace of either ever having been listed at WP:CFD/S. (using Wikiblame: )\nThere is no way that either meets any speedy criteria, so I have simply removed the tags. (talk) \xe2\x80\xa2 (contribs) "',
b"Page moves\nHi there. For future reference, please don't move pages by copying and pasting, as you did with Bill Garner. Rather, use the move tab at the top of the page, as it preserves the article's edit history. Thanks.",
b'"::::There is now documentation concerning Johnson\'s sexual orientation in a biography published by a reliable academic press. This is discussed in the article under ""personal life"". An attempt was just made to delete this reference because, if it had been true, according to the deleter, it would have been reported in the National Enquirer. This is absurd. Using non-existent NE article is even less reliable than an actual NE article. \n\n"',
b'"\nAnd I didn\'t even add that one! You know, I think it\'s official: You are having too damn much fun with that article! (But that\'s a good thing!) (talk) "',
b'"\n\nThe first image ""Turkish Cypriots in Britain"" is inappropriate and features Turkish nationalists carrying flags of Turkey and the Turkish occupation regime, (and not one sight of a flag of the internationally recognised Republic of Cyprus) in a provocative anti-Greek Cypriot rally in support of the illegal Turkish occupation of Cyprus. The image also violates the rules on use of national flags in articles about ethnic minorities living in another country. "',
b'Rather than get all bent out of shape on the AfD page, why not read up on our standards?',
b"January 2009 \n Please stop your disruptive editing. If you continue to vandalize Wikipedia, you will be blocked from editing. \nIf this is a shared IP address, and you didn't make the edit, consider creating an account for yourself so you can avoid further irrelevant notices.",
b'Alphabetize track list \n\nThe track list should probably be alphabetized. The order is seemingly random at the moment, and we have no idea about relative difficulty (and hence set number) of songs.',
b'Thank you for experimenting with the page Guitar on Wikipedia. Your test worked, and has been reverted or removed. Please use the sandbox for any other tests you want to do. Take a look at the welcome page if you would like to learn more about contributing to our encyclopedia. A link to the edit I have reverted can be found here: link. If you believe this edit should not have been reverted, please contact me.',
b"Motion for deletion. \n\nThis has to be one of the stupidest articles on wikipedia. There's a single arch McDonalds a few blocks from where I live. You want to do an article about that too?",
b"Just to clarify why you posted this - was it because your proposing to have it added to the article or just because you wanted to post this to anger people? It really has no place on this page and I wouldn't be surprised if it's deleted."],
dtype=object), array([[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[1, 0, 1, 0, 1, 1],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0],
[0, 0, 0, 0, 0, 0]]))
We also check the shapes. We expect features of shape (batch,) and targets of shape (batch, number of labels).
text train shape: (32,)
text train type: object
label train shape: (32, 6)
label train type: int64
Of course, preprocessing! Text is not the type of input a NN can handle. The TextVectorization layer is meant to handle natural language inputs. The processing of each example contains the following steps:
1. Standardize each example (usually lowercasing + punctuation stripping)
2. Split each example into substrings (usually words)
3. Recombine substrings into tokens (usually ngrams)
4. Index tokens (associate a unique int value with each token)
5. Transform each example using this index, either into a vector of ints or a dense float vector.
For more reference, see the documentation at the following link.
text_vectorization = TextVectorization(
    max_tokens=config.max_tokens,
    standardize="lower_and_strip_punctuation",
    split="whitespace",
    output_mode="int",
    output_sequence_length=config.output_sequence_length,
    pad_to_max_tokens=True
)
# prepare a dataset that only yields raw text inputs (no labels)
text_train_ds = train_ds.map(lambda x, y: x)
# adapt the text vectorization layer to the text data to index the dataset vocabulary
text_vectorization.adapt(text_train_ds)
This layer is set to:
- max_tokens: 20000. It is common for text classification. It is the maximum size of the vocabulary for this layer.
- output_sequence_length: 911. See Figure 3 for the reason why. Only valid in "int" mode.
- output_mode: outputs integer indices, one integer index per split string token. When output_mode == "int", 0 is reserved for masked locations; this reduces the vocab size to max_tokens - 2 instead of max_tokens - 1.
- standardize: "lower_and_strip_punctuation".
- split: on whitespace.
To preserve the original comments as text and also have a tf.data.Dataset in which the text is preprocessed by the TextVectorization layer, it is possible to map the layer over the features of each dataset.
processed_train_ds = train_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_val_ds = val_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
processed_test_ds = test_ds.map(
    lambda x, y: (text_vectorization(x), y),
    num_parallel_calls=tf.data.experimental.AUTOTUNE
)
Define the model using the Functional API.
def get_deeper_lstm_model():
    clear_session()
    inputs = Input(shape=(None,), dtype=tf.int64, name="inputs")
    embedding = Embedding(
        input_dim=config.max_tokens,
        output_dim=config.embedding_dim,
        mask_zero=True,
        name="embedding"
    )(inputs)
    x = Bidirectional(LSTM(256, return_sequences=True, name="bilstm_1"))(embedding)
    x = Bidirectional(LSTM(128, return_sequences=True, name="bilstm_2"))(x)
    # global average pooling
    x = GlobalAveragePooling1D()(x)
    # add regularization
    x = Dropout(0.3)(x)
    x = Dense(64, activation='relu', kernel_regularizer=tf.keras.regularizers.l2(0.01))(x)
    x = LayerNormalization()(x)
    outputs = Dense(len(config.labels), activation='sigmoid', name="outputs")(x)
    model = Model(inputs, outputs)
    model.compile(optimizer='adam', loss="binary_crossentropy", metrics=config.metrics, steps_per_execution=32)
    return model
lstm_model = get_deeper_lstm_model()
lstm_model.summary()
Finally, the model has been trained using 2 callbacks:
- Early Stopping, to avoid consuming the Kaggle GPU time.
- Model Checkpoint, to retrieve the best model training information.
Considering the dataset is imbalanced, to increase the performance we need to calculate the class weights. These will be passed during the training of the model.
class_weight
toxic 0.095900590
severe_toxic 0.009928468
obscene 0.052757858
threat 0.003061800
insult 0.049132042
identity_hate 0.008710911
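The weights in the table above look proportional to each label's share of positive samples (count / n_samples). A hedged sketch of that recipe, on a toy label matrix standing in for ytrain:

```python
import numpy as np

# Toy binary label matrix: 4 samples, 3 labels (stand-in for ytrain).
y = np.array([[1, 0, 0],
              [0, 0, 0],
              [1, 1, 0],
              [0, 0, 0]])

# Per-label positive frequency, as a class-index -> weight dict.
class_weight = {i: float(w) for i, w in enumerate(y.sum(axis=0) / y.shape[0])}
print(class_weight)  # {0: 0.5, 1: 0.25, 2: 0.0}
```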
It is also useful to define the steps per epoch for the train and validation datasets. This step is required to avoid exhausting the dataset during the fit, which happened to me.
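A minimal sketch of the step computation, using the sample and batch sizes from the Config class:

```python
# Values from the Config class above.
train_samples, val_samples, batch_size = 111699, 23936, 32

# One epoch should consume exactly one pass over each (repeated) dataset.
steps_per_epoch = train_samples // batch_size
validation_steps = val_samples // batch_size
print(steps_per_epoch, validation_steps)  # 3490 748
```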
The fit has been done on Kaggle to leverage the GPU. Some considerations about the model:
- .repeat() ensures the model sees the whole dataset.
- epochs is set to 100.
- validation_data has the same repeat.
- callbacks are the ones defined before.
- class_weight ensures the model is trained using the frequency of each class, because our dataset is imbalanced.
- steps_per_epoch and validation_steps depend on the use of repeat.
Now we can import the model and the history trained on Kaggle.
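For reference, the fit call these considerations describe could look roughly like this. It is a sketch, not the exact Kaggle code: the dataset, model and callback names are assumed from earlier cells.

```python
# Arguments the fit relies on (values from the Config class above).
fit_kwargs = dict(
    epochs=100,                    # config.epochs
    steps_per_epoch=111699 // 32,  # train_samples // batch_size
    validation_steps=23936 // 32,  # val_samples // batch_size
)

# Hypothetical call, shown commented out because it needs the trained setup:
# history = lstm_model.fit(
#     processed_train_ds.repeat(),
#     validation_data=processed_val_ds.repeat(),
#     callbacks=[config.get_early_stopping(),
#                config.get_model_checkpoint(config.checkpoint)],
#     class_weight=class_weight,
#     **fit_kwargs,
# )
print(fit_kwargs["steps_per_epoch"])  # 3490
```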
# A tibble: 5 × 2
metric value
<chr> <dbl>
1 loss 0.0542
2 precision 0.789
3 recall 0.671
4 auc 0.957
5 f1_score 0.0293
For the prediction, the model does not need to repeat the dataset, because it has already been trained on all of the training data. Now it just has to consume the new data to make the predictions.
The best way to assess the performance of a multi-label classification is using a confusion matrix. Sklearn has a specific function to create a multi-label confusion matrix, to handle the fact that there can be multiple labels for one prediction.
Grid Search CV is a technique for fine-tuning the hyperparameters of a ML model. It systematically searches through a set of hyperparameter values to find the combination which leads to the best model performance. In this case, I am using KFold Cross Validation, a resampling technique that splits the data into k consecutive folds. Each fold is used once as validation while the k - 1 remaining folds form the training set. See the documentation for more information.
The model is trained to optimize the recall. This decision was made because the cost of missing a True Positive is greater than that of a False Positive: missing an injurious observation is worse than classifying a clean one as bad.
While the KFold Grid Search CV technique is useful to test multiple hyperparameters, it is important to understand the problem we are facing. A multi-label deep learning classifier outputs a vector of per-class probabilities. These need to be converted to a binary vector using a confidence threshold.
Threshold selection means we have to decide which metric to prioritize, based on the problem we are facing and the relative cost of misjudging. We can consider toxic comment filtering a problem similar to cancer diagnostics: it is better to predict cancer in people who do not have it [False Positive] and perform further analysis than to miss the disease in a patient who has it [False Negative].
I decided to train the model on the F1 score to have a model balanced in both precision and recall, and to leave it to the threshold selection to increase the recall performance.
Moreover, the model has been trained on the macro average F1 score, a single performance indicator obtained by averaging the F1 scores (each computed from precision and recall) of the individual classes.
It is useful for imbalanced classes, because it weights each class equally and is not influenced by the number of samples per class. This is set both in config.metrics and in find_optimal_threshold_cv.
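To illustrate the threshold search, here is a sketch on mock probabilities; the post runs config.find_optimal_threshold_cv on the model's real validation predictions, while this inline version shows the core idea (pick the macro-F1-maximizing threshold) without the CV folds.

```python
import numpy as np
from sklearn.metrics import f1_score

# Mock multi-label data: probabilities correlated with the true labels
# (positives land in [0.6, 1.0], negatives in [0.0, 0.4]).
rng = np.random.default_rng(42)
ytrue = rng.integers(0, 2, size=(200, 6))
yproba = np.clip(ytrue * 0.6 + rng.random((200, 6)) * 0.4, 0, 1)

# Evaluate each candidate threshold with macro F1 and keep the best.
best_threshold, best_score = max(
    ((t, f1_score(ytrue, (yproba >= t).astype(int), average="macro"))
     for t in np.arange(0.05, 0.35, 0.05)),
    key=lambda x: x[1],
)
```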
Optimal threshold: 0.15000000000000002
Best score: 0.4788653077945807
Optimal threshold f1 score: 0.15. Best score: 0.4788653.
Optimal threshold recall: 0.05. Best score: 0.8095814.
Optimal threshold: 0.05
Best score: 0.8809499649742268
Optimal threshold roc: 0.05. Best score: 0.88095.
# convert probability predictions to predictions
ypred = predictions >= optimal_threshold_recall # .05
ypred = ypred.astype(int)
# create a plot with 3 by 2 subplots
fig, axes = plt.subplots(3, 2, figsize=(15, 15))
axes = axes.flatten()
mcm = multilabel_confusion_matrix(ytrue, ypred)
# plot the confusion matrices for each label
for i, (cm, label) in enumerate(zip(mcm, config.labels)):
    disp = ConfusionMatrixDisplay(confusion_matrix=cm)
    disp.plot(ax=axes[i], colorbar=False)
    axes[i].set_title(f"Confusion matrix for label: {label}")
plt.tight_layout()
plt.show()
# A tibble: 10 × 5
metrics precision recall `f1-score` support
<chr> <dbl> <dbl> <dbl> <dbl>
1 toxic 0.552 0.890 0.682 2262
2 severe_toxic 0.236 0.917 0.375 240
3 obscene 0.550 0.936 0.692 1263
4 threat 0.0366 0.493 0.0681 69
5 insult 0.471 0.915 0.622 1170
6 identity_hate 0.116 0.720 0.200 207
7 micro avg 0.416 0.896 0.569 5211
8 macro avg 0.327 0.812 0.440 5211
9 weighted avg 0.495 0.896 0.629 5211
10 samples avg 0.0502 0.0848 0.0597 5211
The BiLSTM model, optimized to have a high recall, is performing well enough to make predictions for each label. Considering the low support for the threat label, the performance is not bad: see Table 2 and Figure 1; the threat label is only 0.27% of the observations. The model has been optimized for recall because the cost of not identifying an injurious comment as such is higher than the cost of considering a clean comment injurious.
Possible improvements could be to increase the number of observations, especially for the threat label. In general there are too many clean comments. This could be addressed by undersampling the clean comments, which I explicitly avoided in order to check the performance of the BiLSTM on an imbalanced dataset, leveraging the class weight method.